{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型定义及损失函数实现原理 \n",
    "\n",
    "1. 双线性插值方法定义\n",
    "2. FCN模型搭建\n",
    "3. 损失函数原理补充\n",
    "4. 损失函数计算过程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eGcpSd6QXFnW"
   },
   "outputs": [],
   "source": [
    "# encoding: utf-8\n",
    "import torch \n",
    "from torchvision import models\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9nHINmwPXJ3_"
   },
   "outputs": [],
   "source": [
    "def bilinear_kernel(in_channels, out_channels, kernel_size):\n",
    "    \"\"\"Build transposed-convolution weights that perform bilinear upsampling.\n",
    "\n",
    "    Args:\n",
    "        in_channels: size of the first weight dimension.\n",
    "        out_channels: size of the second weight dimension. Channel indices are\n",
    "            paired element-wise below, so this is meant to equal in_channels.\n",
    "        kernel_size: side length of the square kernel.\n",
    "\n",
    "    Returns:\n",
    "        float32 tensor of shape (in_channels, out_channels, kernel_size,\n",
    "        kernel_size): the 2-D bilinear filter at every paired [i, i] channel\n",
    "        position and zeros everywhere else.\n",
    "    \"\"\"\n",
    "    half = (kernel_size + 1) // 2\n",
    "    # Odd kernels peak on a pixel, even kernels peak between the two middle pixels.\n",
    "    peak = half - 1 if kernel_size % 2 == 1 else half - 0.5\n",
    "    rows, cols = np.ogrid[:kernel_size, :kernel_size]\n",
    "    row_profile = 1 - abs(rows - peak) / half\n",
    "    col_profile = 1 - abs(cols - peak) / half\n",
    "    kernel = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float32)\n",
    "    # Fancy indexing pairs channel i with channel i; all other pairs stay zero.\n",
    "    kernel[np.arange(in_channels), np.arange(out_channels)] = row_profile * col_profile\n",
    "    return torch.from_numpy(kernel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 790
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6488,
     "status": "ok",
     "timestamp": 1586141683245,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "X3P3Ds-GXffM",
    "outputId": "708f5133-7139-4a5c-d2f6-ea08e8d9f2f9",
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (2): ReLU(inplace=True)\n",
       "  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (5): ReLU(inplace=True)\n",
       "  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (9): ReLU(inplace=True)\n",
       "  (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (12): ReLU(inplace=True)\n",
       "  (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (16): ReLU(inplace=True)\n",
       "  (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (19): ReLU(inplace=True)\n",
       "  (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (22): ReLU(inplace=True)\n",
       "  (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (26): ReLU(inplace=True)\n",
       "  (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (29): ReLU(inplace=True)\n",
       "  (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (32): ReLU(inplace=True)\n",
       "  (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (36): ReLU(inplace=True)\n",
       "  (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (39): ReLU(inplace=True)\n",
       "  (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (42): ReLU(inplace=True)\n",
       "  (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pretrained=False builds the network without pretrained weights; enabling it is\n",
    "# recommended in practice (the weights are downloaded on first use).\n",
    "pretrained_net = models.vgg16_bn(pretrained=False)\n",
    "pretrained_net.features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1031,
     "status": "ok",
     "timestamp": 1586090748901,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "qWQ6ZaBfX_TW",
    "outputId": "3010e0df-4908-4b96-819e-599001df47dc"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))"
      ]
     },
     "execution_count": 9,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pretrained_net.features[0]  # first module of the VGG16-BN feature stack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1045,
     "status": "ok",
     "timestamp": 1586090832285,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "FS79IOwhYLxF",
    "outputId": "5dcacea6-8cf9-45a0-9ac6-dae4da2cc157"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 3, 3, 3])"
      ]
     },
     "execution_count": 12,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pretrained_net.features[0].weight.shape # 64 kernels, each of size 3*3*3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 168
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1057,
     "status": "ok",
     "timestamp": 1586090974372,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "vV9aR7ElYGT4",
    "outputId": "08079f72-7c78-4b0f-c195-dfa97014b2d8"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (2): ReLU(inplace=True)\n",
       "  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (5): ReLU(inplace=True)\n",
       "  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pretrained_net.features[:7]  # modules 0-6 of the VGG16-BN feature stack (slice end is exclusive)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"FCN.png\", width=800, heigth=800>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%html\n",
    "<img src=\"FCN.png\" width=\"800\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "r4gF4VaIWtyi"
   },
   "outputs": [],
   "source": [
    "\n",
    "class FCN(nn.Module):\n",
    "    \"\"\"FCN-8s style segmentation head on top of a VGG16-BN feature extractor.\n",
    "\n",
    "    The feature stack is cut after each max-pool into five stages. The deepest\n",
    "    map is upsampled 2x and added to stage 4, that sum is reduced to 256\n",
    "    channels, upsampled 2x and added to stage 3, and the result is projected\n",
    "    to num_classes channels and upsampled 8x back to the input resolution.\n",
    "\n",
    "    Args:\n",
    "        num_classes: number of channels of the predicted score map.\n",
    "        backbone: optional model exposing a VGG16-BN style `.features`\n",
    "            sequence. Defaults to the notebook-level `pretrained_net`.\n",
    "            The stage modules are slices of that sequence, so they are shared\n",
    "            with the backbone rather than copied.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes, backbone=None):\n",
    "        super().__init__()\n",
    "\n",
    "        if backbone is None:\n",
    "            # Original behaviour: use the model built earlier in the notebook.\n",
    "            backbone = pretrained_net\n",
    "        features = backbone.features\n",
    "\n",
    "        # Every slice ends with a MaxPool2d, so each stage halves H and W.\n",
    "        # Channel counts below are those of vgg16_bn.\n",
    "        self.stage1 = features[:7]     # 64 channels\n",
    "        self.stage2 = features[7:14]   # 128 channels\n",
    "        self.stage3 = features[14:24]  # 256 channels\n",
    "        self.stage4 = features[24:34]  # 512 channels\n",
    "        self.stage5 = features[34:]    # 512 channels\n",
    "\n",
    "        # 1x1 score heads. forward() does not evaluate them (the 32s/16s\n",
    "        # predictions are skipped, see the note at the end of __init__); they\n",
    "        # stay registered so the parameter names of the module do not change.\n",
    "        self.scores1 = nn.Conv2d(512, num_classes, 1) # 32s\n",
    "        self.scores2 = nn.Conv2d(512, num_classes, 1) # 16s\n",
    "        self.scores3 = nn.Conv2d(128, num_classes, 1) # 8s\n",
    "        # Transition 1x1 convolutions: they only change the channel count.\n",
    "        self.conv_trans1 = nn.Conv2d(512, 256, 1)\n",
    "        self.conv_trans2 = nn.Conv2d(256, num_classes, 1)\n",
    "        # ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, **args)\n",
    "\n",
    "        # Transposed convolution that restores the resolution by a factor of 8\n",
    "        # (kernel 16, stride 8, padding 4).\n",
    "        self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, 16, 8, 4, bias=False)\n",
    "        # Initialise the kernel with bilinear-interpolation weights.\n",
    "        self.upsample_8x.weight.data = bilinear_kernel(num_classes, num_classes, 16)\n",
    "        # 2x upsampling of the 512-channel stage-5 map.\n",
    "        self.upsample_2x_1 = nn.ConvTranspose2d(512, 512, 4, 2, 1, bias=False)\n",
    "        self.upsample_2x_1.weight.data = bilinear_kernel(512, 512, 4)\n",
    "\n",
    "        # 2x upsampling of the 256-channel fused map.\n",
    "        self.upsample_2x_2 = nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False)\n",
    "        self.upsample_2x_2.weight.data = bilinear_kernel(256, 256, 4)\n",
    "\n",
    "        # The 16s and 32s upsampling branches are omitted because the FCN paper\n",
    "        # reports them to perform worse; adding them would also require\n",
    "        # changing scores1.\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return a score map with num_classes channels and the H x W of x.\n",
    "\n",
    "        Shape comments assume a (N, 3, 352, 480) input and the VGG16-BN\n",
    "        backbone. H and W should be multiples of 32, otherwise the floor\n",
    "        division in the pooling layers makes the skip additions mismatch.\n",
    "        \"\"\"\n",
    "        s1 = self.stage1(x)   # (N,  64, 176, 240)\n",
    "        s2 = self.stage2(s1)  # (N, 128,  88, 120)\n",
    "        s3 = self.stage3(s2)  # (N, 256,  44,  60)\n",
    "        s4 = self.stage4(s3)  # (N, 512,  22,  30)\n",
    "        s5 = self.stage5(s4)  # (N, 512,  11,  15)\n",
    "\n",
    "        # The 32s/16s score heads are not evaluated here: their outputs never\n",
    "        # contributed to the returned prediction.\n",
    "\n",
    "        # Upsample the deepest map 2x and fuse it with stage 4 (skip connection).\n",
    "        s5 = self.upsample_2x_1(s5)  # (N, 512, 22, 30)\n",
    "        add1 = s5 + s4               # (N, 512, 22, 30)\n",
    "\n",
    "        # Reduce 512 -> 256 channels so the map can be fused with stage 3.\n",
    "        add1 = self.conv_trans1(add1)\n",
    "        add1 = self.upsample_2x_2(add1)  # (N, 256, 44, 60)\n",
    "        add2 = add1 + s3                 # (N, 256, 44, 60)\n",
    "\n",
    "        # Project to num_classes channels, then restore the resolution 8x.\n",
    "        output = self.conv_trans2(add2)  # (N, num_classes, 44, 60)\n",
    "        output = self.upsample_8x(output)\n",
    "        return output  # (N, num_classes, 352, 480)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 521
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1179,
     "status": "ok",
     "timestamp": 1586142357074,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "fkxtVerlezwl",
    "outputId": "36bb2e8d-72da-43cd-bbea-765c3903d356"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 9,  3,  7,  ...,  9,  1,  3],\n",
      "         [ 2, 10, 11,  ...,  5, 10,  7],\n",
      "         [ 5,  3,  1,  ...,  6,  6,  1],\n",
      "         ...,\n",
      "         [ 7, 11, 10,  ...,  4,  7,  0],\n",
      "         [ 2,  6, 11,  ...,  3,  3,  5],\n",
      "         [ 1,  2, 11,  ..., 10,  5,  7]]])\n",
      "tensor([[[[-0.1696, -0.0032, -0.7078,  ...,  0.5222, -0.7567, -0.5255],\n",
      "          [-0.1247,  0.3469, -0.0461,  ...,  0.1995, -1.8711, -1.2577],\n",
      "          [ 0.8483,  0.0107,  0.1716,  ..., -0.8260,  0.7499, -2.0751],\n",
      "          ...,\n",
      "          [ 1.1645, -0.3760,  0.4654,  ...,  0.6341, -0.9732,  0.3170],\n",
      "          [ 0.6558, -1.7625, -0.0942,  ...,  1.1686,  2.0618, -1.5564],\n",
      "          [-0.0292, -0.0116,  0.2264,  ...,  0.1721,  0.1979,  0.1158]],\n",
      "\n",
      "         [[ 1.0275, -0.8438,  0.9060,  ..., -1.8316,  1.0276, -0.1227],\n",
      "          [ 0.2348, -0.0512,  0.7146,  ...,  1.1838,  0.0182, -0.3521],\n",
      "          [-0.8555,  0.5010, -0.3183,  ..., -1.7400, -1.7042, -0.9826],\n",
      "          ...,\n",
      "          [-0.7848,  0.3533,  0.0935,  ..., -1.0900,  0.4637, -0.5743],\n",
      "          [ 2.8712, -0.1722, -0.6189,  ...,  1.0843, -0.1296,  0.3218],\n",
      "          [-0.2604, -0.0880, -1.5435,  ...,  0.5058,  0.8597,  0.9267]],\n",
      "\n",
      "         [[-0.0117, -0.0831, -0.8518,  ...,  0.1229, -1.8050, -0.6117],\n",
      "          [-0.0720,  0.8733, -1.0352,  ..., -0.6158, -0.6748,  1.6856],\n",
      "          [ 0.6656,  1.0486,  0.3979,  ...,  0.2742, -0.4600,  0.8802],\n",
      "          ...,\n",
      "          [ 1.2418,  0.2354, -1.6631,  ..., -0.5167, -0.2216,  0.5159],\n",
      "          [-0.6694, -0.4663, -1.7774,  ..., -0.1500,  0.3772,  0.4272],\n",
      "          [ 0.1405,  0.5880, -0.2824,  ..., -0.4916, -1.1650, -1.1030]]]])\n"
     ]
    }
   ],
   "source": [
    "# Fake label map: one class id in [0, 12) for every pixel of a 352 x 480 image.\n",
    "gt = np.random.rand(1, 352, 480)*12\n",
    "gt = gt.astype(np.int64)\n",
    "gt = torch.from_numpy(gt)\n",
    "print(gt)\n",
    "# Fake input batch. The module is imported as `torch`; there is no `t` alias,\n",
    "# so `t.randn` raised NameError on a fresh kernel.\n",
    "x = torch.randn(1, 3, 352, 480)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 84
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3438,
     "status": "ok",
     "timestamp": 1586142038276,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "uObC1lrtaBeY",
    "outputId": "88779a8f-b1ce-45c6-9b2b-ff7351f61755"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 12, 352, 480])\n",
      "torch.Size([1, 12, 352, 480])\n",
      "torch.Size([1, 352, 480])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(2.7314, grad_fn=<NllLoss2DBackward>)"
      ]
     },
     "execution_count": 19,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = FCN(12) # 12 classes\n",
    "y = net(x)\n",
    "print(y.shape) # score map predicted by the whole network\n",
    "out = F.log_softmax(y, dim=1) # log-probabilities over dim 1 (the class channels)\n",
    "print(out.shape)\n",
    "\n",
    "criterion = nn.NLLLoss() # applied to log-softmax output this is the cross-entropy loss\n",
    "# ground-truth label map\n",
    "print(gt.shape)\n",
    "loss = criterion(out, gt)\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 912,
     "status": "ok",
     "timestamp": 1586142041984,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "Riiq1Q5yboLs",
    "outputId": "908abd3f-a7c5-4454-cfce-43a57e96a110"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.7313966751098633"
      ]
     },
     "execution_count": 20,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss.item()  # the scalar loss as a plain Python number"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OLChToAoB5MT"
   },
   "source": [
    "**损失函数**\n",
    "\n",
    "　　损失的回传最好精细到每个像素上，具体的呈现一下数值是怎么来的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"OneHot2.png\", width=500, heigth=500>\n",
       "<img src=\"OneHot.png\", width=600, heigth=600>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%html\n",
    "<img src=\"OneHot2.png\" width=\"500\">\n",
    "<img src=\"OneHot.png\" width=\"600\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4LDWqdIrMi1t"
   },
   "source": [
    "**NLLLoss**\n",
    "\n",
    "　　这个损失函数的计算可以表达为：`loss(input, class) = -input[class]`。举例说明，三分类任务，输入`input=[-1.233, 2.657, 0.534]`， 真实标签类别`class=2`，则`loss=-0.534`，就是在对应类别的输出上取一个负号。实际应用： 常用于多分类任务，但是input在输入NLLLoss()之前，需要对input进行log_softmax函数激活，即将input转换成概率分布的形式，并且取对数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 252
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1026,
     "status": "ok",
     "timestamp": 1586155757500,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "TvIcUhxoCcqp",
    "outputId": "6169f79f-2757-404c-9608-886d7a9e3ef2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[0, 1, 1],\n",
      "         [0, 0, 0]]])\n",
      "========================================\n",
      "tensor([[[[ 0.0655, -0.4263,  0.2459],\n",
      "          [ 0.8114,  0.3539,  0.0906]],\n",
      "\n",
      "         [[-1.4042,  0.2177,  0.2541],\n",
      "          [-0.9366,  0.1628, -0.4252]]]])\n",
      "----------------------------------------\n",
      "tensor([[[[-0.2070, -1.0661, -0.6972],\n",
      "          [-0.1605, -0.6022, -0.4681]],\n",
      "\n",
      "         [[-1.6767, -0.4221, -0.6891],\n",
      "          [-1.9085, -0.7933, -0.9839]]]])\n"
     ]
    }
   ],
   "source": [
    "# fake label data: a 1 x 2 x 3 map of class ids in {0, 1}\n",
    "gt = np.random.rand(1, 2, 3)*2\n",
    "gt = gt.astype(np.int64)\n",
    "gt = torch.from_numpy(gt)\n",
    "\n",
    "# fake scores: 2 classes (dim 1) at each of the 2 x 3 positions\n",
    "x = torch.randn(1, 2, 2, 3)\n",
    "out = F.log_softmax(x, dim=1)  # log-probabilities over the class axis\n",
    "\n",
    "print(gt)\n",
    "print('='*40)\n",
    "print(x)\n",
    "print('-'*40)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WlFwxUPWVq92"
   },
   "source": [
    "　　直白而言是按标签给出的像素点的类别，去每个像素点对应通道上找相应的值算到损失里。为什么可以这样呢，因为如果该点被正确分类，即该点经softmax后在对应类别位置的概率为1、其他位置的概率为0，则取对数后该位置的值为log(1)=0，即正确分类的损失为0；对应类别的预测概率p越小，-log(p)越大，损失也就越大"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 50
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1195,
     "status": "ok",
     "timestamp": 1586155762814,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "nP1ShDWNmOiw",
    "outputId": "3d1ad227-50e9-45b5-be66-9490c670a0a5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.2070, 0.4221, 0.6891],\n",
       "         [0.1605, 0.6022, 0.4681]]])"
      ]
     },
     "execution_count": 55,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# gt = tensor([[[0, 1, 1],\n",
    "#         [0, 0, 0]]])\n",
    "# gt: dim(batch, rows, cols)\n",
    "\n",
    "# out = tensor([[[[-0.2070, -1.0661, -0.6972],\n",
    "#          [-0.1605, -0.6022, -0.4681]],\n",
    "#\n",
    "#          [[-1.6767, -0.4221, -0.6891],\n",
    "#          [-1.9085, -0.7933, -0.9839]]]])\n",
    "# out: dim(batch, channel, rows, cols)\n",
    "\n",
    "# reduction='none' keeps one loss value per pixel instead of averaging them\n",
    "criterion = nn.NLLLoss(reduction='none') # default reduction='mean'\n",
    "loss = criterion(out, gt)\n",
    "loss\n",
    "\n",
    "# loss = tensor([[[0.2070, 0.4221, 0.6891],\n",
    "#          [0.1605, 0.6022, 0.4681]]])\n",
    "# loss: dim(batch, rows, cols)\n",
    "\n",
    "# loss[0][0][0] = 0.2070 comes from -out[0][i][0][0] with i = gt[0][0][0] = 0\n",
    "# loss[0][1][0] = 0.1605 comes from -out[0][i][1][0] with i = gt[0][1][0] = 0\n",
    "# loss[0][0][1] = 0.4221 comes from -out[0][i][0][1] with i = gt[0][0][1] = 1\n",
    "\n",
    "# (0.2070 + 0.4221 + 0.6891 + 0.1605 + 0.6022 + 0.4681) / 6 = 0.4248\n",
    "\n",
    "# criterion = nn.NLLLoss() # default reduction='mean'\n",
    "# loss = criterion(out, gt) = 0.4248\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1454,
     "status": "ok",
     "timestamp": 1586157082016,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "C2aKjVyNmtbs",
    "outputId": "b8288a25-5ab5-4e37-c763-59b726aabfbd"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.4248333333333334"
      ]
     },
     "execution_count": 69,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# manual check: mean of the six per-pixel losses\n",
    "(0.2070 + 0.4221 + 0.6891 + 0.1605 + 0.6022 + 0.4681) / 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1217,
     "status": "ok",
     "timestamp": 1586157084223,
     "user": {
      "displayName": "Rongfan Liao",
      "photoUrl": "",
      "userId": "07803922812103577726"
     },
     "user_tz": -480
    },
    "id": "m2tuvp_CB9wn",
    "outputId": "55eb7a74-269e-42d7-a3c4-f26a64351488"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4248)"
      ]
     },
     "execution_count": 70,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# default reduction='mean': the per-pixel losses are averaged into one scalar\n",
    "criterion = nn.NLLLoss()\n",
    "loss = criterion(out, gt)\n",
    "loss"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyPDRiwhK56W3Y6GaXsAVKX9",
   "collapsed_sections": [],
   "name": "model and loss.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
